{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0373c4a-e565-4e8f-a87f-aae932d3aeed",
   "metadata": {
    "id": "b0373c4a-e565-4e8f-a87f-aae932d3aeed"
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GitHub\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n",
    "\n",
    "\n",
    "NOTE: User is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use.\n",
    "\"\"\"\n",
    "# If you're using Google Colab and not running locally, run this cell.\n",
    "import os\n",
    "\n",
    "# Install dependencies\n",
    "!pip install wget\n",
    "!apt-get install sox libsndfile1 ffmpeg\n",
    "!pip install text-unidecode\n",
    "!pip install matplotlib>=3.3.2\n",
    "\n",
    "## Install NeMo\n",
    "BRANCH = 'main'\n",
    "!python -m pip install \"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@$BRANCH\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c021f07-0576-491d-b73c-6c65c8501351",
   "metadata": {
    "id": "6c021f07-0576-491d-b73c-6c65c8501351"
   },
   "source": [
    "# Multi Task Adaptation with Adapters\n",
    "\n",
    "\n",
    "In earlier tutorials, we utilized a specific model for one task - for example, an ASR model (CTC, RNN-T etc) for the singular task of Speech Recognition. This is very useful if we want to specialize one task per model, but it can be expensive to deploy a fleet of models for each task, and learn routers to pass user tasks to correct models.\n",
    "\n",
    "We now support Multi Task models in NeMo, such that a single model can perform multiple tasks such as speech recognition, speech translation, voice activity detection, and more in the future. With one model supporting multiple tasks, we can simplify the task of deploying models and also hope to leverage individual tasks to improve each other (for example: you do need strong speech recognition first before you start doing translation).\n",
    "\n",
    "---\n",
    "\n",
    "Multi Task (Canary) models are highly capable large neural networks capable of things like speech recognition, X to English and English to X translation and able to select whether to transcribe speech with punctuation and capitalization. These huge models are trained on several thousand hours of speech and text data, making it challenging to adapt to new datasets.\n",
    "\n",
    "In the previous tutorial for [ASR Adapters](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb), we used small adapter modules to tune a large ASR model on a small amount of data. In this tutorial, we will adapt a [Nvidia Canary](https://huggingface.co/nvidia/canary-1b) model onto a small amount of speech data for both Automatic Speech Recognition (ASR) and Automatic Speech Translation (AST).\n",
    "\n",
    "In this tutorial, we will also demonstrate a simple way of creating custom Data Modules from PyTorch Lightning to design custom datasets and data loaders for the highly flexible Multi Task Models in NeMo ASR. This offers users more flexibility in designing new tasks, and finetuning the models on small amounts of data."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbe2f8eb-204f-4d90-bb0a-a49d994f1ed7",
   "metadata": {
    "id": "cbe2f8eb-204f-4d90-bb0a-a49d994f1ed7"
   },
   "source": [
    "----\n",
    "\n",
    "First, lets instantiate the [Canary](https://huggingface.co/nvidia/canary-1b) model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46c3e5c1-b4f2-4f84-89d6-c77bbe7ebe4f",
   "metadata": {
    "id": "46c3e5c1-b4f2-4f84-89d6-c77bbe7ebe4f"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import nemo.collections.asr as nemo_asr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48b9677b-b1d9-4361-becf-ee84fe8d53ca",
   "metadata": {
    "id": "48b9677b-b1d9-4361-becf-ee84fe8d53ca"
   },
   "outputs": [],
   "source": [
    "model = nemo_asr.models.ASRModel.from_pretrained(\"nvidia/canary-1b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c0c87c9-5290-4634-9338-818f181c936a",
   "metadata": {
    "id": "6c0c87c9-5290-4634-9338-818f181c936a"
   },
   "source": [
    "# Enable Adapter Support in Model\n",
    "\n",
    "New in NeMo 2.0, we now have a simple utility function to convert the model into one that supports adapters, called `replace_adapter_compatible_modules()`.\n",
    "\n",
    "This will go through the full model and check modules if they support adapters, and then enable that ability. Once used, you can freely use adapter methods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfd72316-630b-43c3-9a02-65bb2dabe624",
   "metadata": {
    "id": "bfd72316-630b-43c3-9a02-65bb2dabe624",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model.replace_adapter_compatible_modules()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30505bd5-323f-4e90-a941-d0de3f6e55e3",
   "metadata": {
    "id": "30505bd5-323f-4e90-a941-d0de3f6e55e3"
   },
   "source": [
    "## Check Which Targets Are Supported For This Model\n",
    "\n",
    "Now that the model has enabled adapter support, lets take a look at which of its modules support adapter modules to be attached to them.\n",
    "\n",
    "**Note**\n",
    "Below, you might see an adapter module with no name `''` - this corresponds to the \"default\" model target if the target isn't specified. Users can chose to simply skip the module name when adding an adapter, and the model will by default add adapters to the encoder module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13bcf42e-d33a-4364-8d0f-ab59a26ffa7c",
   "metadata": {
    "id": "13bcf42e-d33a-4364-8d0f-ab59a26ffa7c"
   },
   "outputs": [],
   "source": [
    "model.adapter_module_names"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67324f6a-ffff-47a7-9ee5-dc93819f6ffd",
   "metadata": {
    "id": "67324f6a-ffff-47a7-9ee5-dc93819f6ffd"
   },
   "source": [
    "## Prepare the Adapter\n",
    "\n",
    "Now that we know which modules are supported, lets create a simple adapter module for the encoder and decoder modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65ec3b2b-3f84-43ed-8a90-085aee383ea6",
   "metadata": {
    "id": "65ec3b2b-3f84-43ed-8a90-085aee383ea6"
   },
   "outputs": [],
   "source": [
    "from nemo.collections.common.parts import LinearAdapterConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47aab832-bfec-4cca-b4ee-868ea1af9869",
   "metadata": {
    "id": "47aab832-bfec-4cca-b4ee-868ea1af9869"
   },
   "outputs": [],
   "source": [
    "input_dim = model.cfg.encoder.d_model\n",
    "adapter_dim = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd519281-ad45-4719-9ad6-561e6192717f",
   "metadata": {
    "id": "cd519281-ad45-4719-9ad6-561e6192717f"
   },
   "outputs": [],
   "source": [
    "enc_adapter_cfg = LinearAdapterConfig(in_features=input_dim, dim=adapter_dim)\n",
    "dec_adapter_cfg = LinearAdapterConfig(in_features=input_dim, dim=adapter_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f147fc89-ab93-4454-ad6b-909288a452a2",
   "metadata": {
    "id": "f147fc89-ab93-4454-ad6b-909288a452a2"
   },
   "source": [
    "## Add Adapter Modules\n",
    "\n",
    "Now that we have the adapter configs prepared, lets add them to the model !\n",
    "\n",
    "We provide the target module by using `target:adapter_name` when calling `add_adapter()` - this tells the model to setup an adapter called `adapter_name` to the module denoted by `target` with the config `cfg`."
   ]
  },
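  {
   "cell_type": "markdown",
   "id": "c7a1e3d2-6b4f-4e0a-8d2c-1a9b5f3e7c01",
   "metadata": {
    "id": "c7a1e3d2-6b4f-4e0a-8d2c-1a9b5f3e7c01"
   },
   "source": [
    "---\n",
    "\n",
    "To make the naming convention concrete, here is a small hypothetical helper that splits an adapter name the way this tutorial describes. Note that `split_adapter_name` is not a NeMo function. The part before `:` is the target module and the part after it is the adapter's own name. A name without `:` falls back to the default target `''`, which is the encoder. This sketches the convention only and is not NeMo's internal implementation.\n",
    "\n",
    "```python\n",
    "def split_adapter_name(name: str) -> tuple[str, str]:\n",
    "    # Illustrative only: mirrors the 'target:adapter_name' convention described in this tutorial.\n",
    "    if ':' in name:\n",
    "        target, adapter_name = name.split(':', 1)\n",
    "    else:\n",
    "        # No target given: use the default target '' (the encoder).\n",
    "        target, adapter_name = '', name\n",
    "    return target, adapter_name\n",
    "\n",
    "print(split_adapter_name('encoder:enc'))         # ('encoder', 'enc')\n",
    "print(split_adapter_name('transf_decoder:dec'))  # ('transf_decoder', 'dec')\n",
    "print(split_adapter_name('enc'))                 # ('', 'enc')\n",
    "```"
   ]
  },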
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a23256ce-bc09-4fb0-8c3b-214519b8774b",
   "metadata": {
    "id": "a23256ce-bc09-4fb0-8c3b-214519b8774b"
   },
   "outputs": [],
   "source": [
    "model.add_adapter(name=\"encoder:enc\", cfg=enc_adapter_cfg)\n",
    "model.add_adapter(name=\"transf_decoder:dec\", cfg=dec_adapter_cfg)\n",
    "\n",
    "print(\"Added adapters!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2dbe9b7b-9a3d-4504-a652-1d90701cbbf8",
   "metadata": {
    "id": "2dbe9b7b-9a3d-4504-a652-1d90701cbbf8"
   },
   "source": [
    "## Freeze Original Module Parameters and Unfreeze Adapter Weights Only\n",
    "\n",
    "When tuning adapters, we usually freeze the entire base model and only tune the adapters. This prevents the need for large amounts of data, preserves a lot of memory (since the full model doesnt need backward pass, only the adapters) and makes it easier to adapt huge models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f8162dd-0373-4e65-aa8a-f458a1633578",
   "metadata": {
    "id": "2f8162dd-0373-4e65-aa8a-f458a1633578",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model.freeze()\n",
    "model.unfreeze_enabled_adapters()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b3795a4-fcfe-49ee-a76f-1cb77d99ace1",
   "metadata": {
    "id": "0b3795a4-fcfe-49ee-a76f-1cb77d99ace1"
   },
   "source": [
    "----\n",
    "\n",
    "Lets make sure that the number of trainable parameters is a lot smaller (< 1 M) than the total number of params (1 B)."
   ]
  },
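  {
   "cell_type": "markdown",
   "id": "c7a1e3d2-6b4f-4e0a-8d2c-1a9b5f3e7c02",
   "metadata": {
    "id": "c7a1e3d2-6b4f-4e0a-8d2c-1a9b5f3e7c02"
   },
   "source": [
    "---\n",
    "\n",
    "As a rough sanity check on the < 1 M figure, the sketch below estimates the adapter parameter count by hand. It rests on several assumptions that have not been verified against the checkpoint:\n",
    "\n",
    "* Each `LinearAdapter` is a LayerNorm followed by two bias-free linear projections.\n",
    "* One adapter is added per layer.\n",
    "* Canary-1B has `d_model = 1024`, with 24 encoder layers and 24 decoder layers.\n",
    "\n",
    "Compare the result with the trainable-parameter count that `model.summarize()` reports.\n",
    "\n",
    "```python\n",
    "def linear_adapter_params(d_model: int, adapter_dim: int) -> int:\n",
    "    # Assumed LinearAdapter layout: LayerNorm(d_model) -> Linear(d_model, adapter_dim, bias=False)\n",
    "    # -> activation -> Linear(adapter_dim, d_model, bias=False).\n",
    "    norm = 2 * d_model            # LayerNorm weight and bias\n",
    "    down = d_model * adapter_dim  # down-projection weights\n",
    "    up = adapter_dim * d_model    # up-projection weights\n",
    "    return norm + down + up\n",
    "\n",
    "per_adapter = linear_adapter_params(d_model=1024, adapter_dim=8)\n",
    "num_layers = 24 + 24  # assumed: 24 encoder layers + 24 decoder layers, one adapter each\n",
    "print(per_adapter, per_adapter * num_layers)  # 18432 884736\n",
    "```\n",
    "\n",
    "Under these assumptions, the adapters add roughly 0.9 M parameters, which is consistent with the < 1 M expectation above."
   ]
  },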
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58453f40-d72d-4f9b-a427-3fb63787f3d6",
   "metadata": {
    "id": "58453f40-d72d-4f9b-a427-3fb63787f3d6"
   },
   "outputs": [],
   "source": [
    "model.summarize()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa713f4a-ec16-4e2a-aeb3-ac7c4090f20f",
   "metadata": {
    "id": "aa713f4a-ec16-4e2a-aeb3-ac7c4090f20f"
   },
   "source": [
    "## Check Enabled Adapters\n",
    "\n",
    "Here, we check that the adapters that we named above (`enc` and `dec`) are both setup and enabled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d69f09d9-411e-420e-8f17-c86391e88fc3",
   "metadata": {
    "id": "d69f09d9-411e-420e-8f17-c86391e88fc3"
   },
   "outputs": [],
   "source": [
    "model.get_enabled_adapters()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f_XpTJx9hQXy",
   "metadata": {
    "id": "f_XpTJx9hQXy"
   },
   "source": [
    "# Customizing Multi Task Models\n",
    "\n",
    "In the following section, we will take a deeper look into what are the components that compose a Multi Task Model and how users can override each of these parts to create their own customizable multi task models.\n",
    "\n",
    "---\n",
    "\n",
    "In this tutorial, we will only see the internal components such as the prompt format and dataset construction, but not change them.\n",
    "\n",
    "In a following tutorial, we will show how to add an additional task to a pre-trained Multi Task Model using a pre-trained model as a starting point."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f0beb8c-7b12-4169-a3f7-1639bdaf6160",
   "metadata": {
    "id": "6f0beb8c-7b12-4169-a3f7-1639bdaf6160"
   },
   "source": [
    "# Prompt Handling for Multi Task Models\n",
    "Nvidia Canary is our first model that is a Multi Task Model.\n",
    "\n",
    "Multi Task models utilize a prompt format, similar to those used in Large Language Models, in order to denote to the model which task is to be performed, which langauge is being spoken and what language should the output transcript be in, whether to provide punctuation and capitalization or not, and so much more in the future !\n",
    "\n",
    "Lets take a look at the model's `prompt` for the Canary model that we have created -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56a78cd0-afaf-4272-898f-d9e13ba871d3",
   "metadata": {
    "id": "56a78cd0-afaf-4272-898f-d9e13ba871d3"
   },
   "outputs": [],
   "source": [
    "model.prompt_format"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cbaf28a-1f10-4da3-a3ed-53b2239baa49",
   "metadata": {
    "id": "9cbaf28a-1f10-4da3-a3ed-53b2239baa49"
   },
   "source": [
    "----\n",
    "\n",
    "This gives us the prompt format functions name, which we will see below points to a prompt format function that reads in manifest items and maps it to the template."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "087d1f60-3679-4593-840f-8d0fbd8a0e3e",
   "metadata": {
    "id": "087d1f60-3679-4593-840f-8d0fbd8a0e3e"
   },
   "source": [
    "## Reuse / Register a Prompt Format Function\n",
    "\n",
    "When we print `model.prompt_format` it writes `canary` which is one of the registered prompt templates available in NeMo ASR.\n",
    "For simplicity's sake, we will continue to use the same prompt format for this tutorial. However, we enable users to define their own prompt formats and register them as needed.\n",
    "\n",
    "Let's see what the `canary` prompt format looks like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c202abaf-63ca-4475-a2bb-3b487be8e375",
   "metadata": {
    "id": "c202abaf-63ca-4475-a2bb-3b487be8e375"
   },
   "outputs": [],
   "source": [
    "from nemo.collections.common.data.prompt_fn import get_prompt_format_fn, registered_prompt_format_fn\n",
    "from nemo.collections.common.prompts import CanaryPromptFormatter, PromptFormatter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07c56dc3-fe42-49fc-936c-770ec17a29ac",
   "metadata": {
    "id": "07c56dc3-fe42-49fc-936c-770ec17a29ac",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# sample audio data\n",
    "import numpy as np\n",
    "import soundfile as sf\n",
    "from io import BytesIO\n",
    "from lhotse import Recording, SupervisionSegment, CutSet\n",
    "\n",
    "def create_sine_wave(duration: float = 1.0, sample_rate: int = 16000, frequency: float = 440.0):\n",
    "    \"\"\"Generate a sine wave of specified duration and frequency.\"\"\"\n",
    "    t = np.linspace(0, duration, int(duration * sample_rate))\n",
    "    return np.sin(2 * np.pi * frequency * t)\n",
    "\n",
    "audio = create_sine_wave()\n",
    "    \n",
    "    # Convert to 16-bit PCM WAV format in memory\n",
    "buffer = BytesIO()\n",
    "sf.write(buffer, audio, 16000, format='WAV')\n",
    "audio_bytes = buffer.getvalue()\n",
    "\n",
    "# Create a Recording from the bytes\n",
    "cut = Recording.from_bytes(\n",
    "    data=audio_bytes,\n",
    "    recording_id=\"generated_sine\"\n",
    ").to_cut()\n",
    "\n",
    "cut.supervisions = [SupervisionSegment(cut.id, cut.recording.id, start=0, duration=cut.duration, text=\"I said something\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c56dcaf-ac27-4e92-8f56-9b5a7daf0034",
   "metadata": {
    "id": "07c56dc3-fe42-49fc-936c-770ec17a29ac",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "canary_prompt_format_fn = get_prompt_format_fn(cut, CanaryPromptFormatter)\n",
    "canary_prompt_format_fn?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1170b57c-f4c7-432f-91bb-1dbf73063d60",
   "metadata": {
    "id": "1170b57c-f4c7-432f-91bb-1dbf73063d60"
   },
   "source": [
    "### Registering a New Prompt Format Function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d11a8a05-6ba7-41f3-97ab-43453a59c860",
   "metadata": {
    "id": "d11a8a05-6ba7-41f3-97ab-43453a59c860"
   },
   "source": [
    "Just to show that this is user-configurable, we show how to register a dummy prompt format below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f77378ff-d5de-4b86-bfaf-e62b51c7f9ce",
   "metadata": {
    "id": "f77378ff-d5de-4b86-bfaf-e62b51c7f9ce"
   },
   "outputs": [],
   "source": [
    "from nemo.collections.common.prompts import PromptFormatter\n",
    "from lhotse.cut import Cut\n",
    "@registered_prompt_format_fn(Cut, PromptFormatter)\n",
    "def canary_custom(example, formatter):\n",
    "    \"\"\" Users can implement this as needed \"\"\"\n",
    "    raise NotImplementedError()\n",
    "\n",
    "print(\"Registered prompt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb02f068-8fee-46e1-8096-910062668173",
   "metadata": {
    "id": "cb02f068-8fee-46e1-8096-910062668173"
   },
   "outputs": [],
   "source": [
    "temp = get_prompt_format_fn(Cut, PromptFormatter)\n",
    "temp.__name__"
   ]
  },
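  {
   "cell_type": "markdown",
   "id": "c7a1e3d2-6b4f-4e0a-8d2c-1a9b5f3e7c03",
   "metadata": {
    "id": "c7a1e3d2-6b4f-4e0a-8d2c-1a9b5f3e7c03"
   },
   "source": [
    "---\n",
    "\n",
    "Conceptually, a decorator like `registered_prompt_format_fn` stores a function in a lookup table keyed by the example type and the prompt formatter type, and `get_prompt_format_fn` retrieves it from that table. The sketch below is a simplified, hypothetical version of this idea. The names `register_format_fn` and `lookup_format_fn`, and the stand-in classes, are made up for illustration, and NeMo's actual lookup rules may differ. The lookup walks the class hierarchies so that a subclass instance (in the way a `MonoCut` is a `Cut`) can reuse a function registered for its base class.\n",
    "\n",
    "```python\n",
    "from typing import Callable\n",
    "\n",
    "# Hypothetical registry: (example type, formatter type) -> prompt format function.\n",
    "_REGISTRY: dict[tuple[type, type], Callable] = {}\n",
    "\n",
    "def register_format_fn(example_type: type, formatter_type: type):\n",
    "    # Decorator factory: store the decorated function under the pair of types.\n",
    "    def decorator(fn: Callable) -> Callable:\n",
    "        _REGISTRY[(example_type, formatter_type)] = fn\n",
    "        return fn\n",
    "    return decorator\n",
    "\n",
    "def lookup_format_fn(example, formatter) -> Callable:\n",
    "    # Accept either instances or classes, as the notebook cells above do.\n",
    "    ex_t = example if isinstance(example, type) else type(example)\n",
    "    fmt_t = formatter if isinstance(formatter, type) else type(formatter)\n",
    "    # Walk both class hierarchies so subclasses reuse a base-class registration.\n",
    "    for e in ex_t.__mro__:\n",
    "        for f in fmt_t.__mro__:\n",
    "            if (e, f) in _REGISTRY:\n",
    "                return _REGISTRY[(e, f)]\n",
    "    raise KeyError(f'No prompt format fn for ({ex_t.__name__}, {fmt_t.__name__})')\n",
    "\n",
    "# Stand-in classes playing the roles of Cut/MonoCut and PromptFormatter/CanaryPromptFormatter.\n",
    "class BaseCut: pass\n",
    "class MonoCutLike(BaseCut): pass\n",
    "class BaseFormatter: pass\n",
    "class CanaryLikeFormatter(BaseFormatter): pass\n",
    "\n",
    "@register_format_fn(BaseCut, BaseFormatter)\n",
    "def generic_fn(example, formatter):\n",
    "    return 'generic'\n",
    "\n",
    "@register_format_fn(BaseCut, CanaryLikeFormatter)\n",
    "def canary_like_fn(example, formatter):\n",
    "    return 'canary'\n",
    "\n",
    "print(lookup_format_fn(MonoCutLike(), CanaryLikeFormatter).__name__)  # canary_like_fn\n",
    "print(lookup_format_fn(MonoCutLike, BaseFormatter).__name__)          # generic_fn\n",
    "```"
   ]
  },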
  {
   "cell_type": "markdown",
   "id": "f14aa85b-71cb-4813-837b-b28a384685dc",
   "metadata": {
    "id": "f14aa85b-71cb-4813-837b-b28a384685dc"
   },
   "source": [
    "## Create / Reuse a Prompt Format\n",
    "\n",
    "Canary Multi Task Model comes with a pre-defined prompt template, so we need to provide it data in a format that can be handled by that prompt format class.\n",
    "\n",
    "A `PromptFormatter` is a special class that defines the dialog template of the order of turns that occur in a model's prompt. For example, in Language Models, we normally may begin with either a `System` or `User` turn, followed by an `Assistant` turn which produces an output from the model. Similarly in Multi Task models, we enable support for such a usage pattern.\n",
    "\n",
    "Do note: Current generation of Canary models are not trained to operate on multi turn conversations, however future variants of Multi Task models may support such usage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35530cad-84d7-422b-82c5-1bda5c1a4497",
   "metadata": {
    "id": "35530cad-84d7-422b-82c5-1bda5c1a4497",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Let's review the actual prompt formatter clas docs\n",
    "model.prompt?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd0c0d1-da8a-4de6-9efc-86a7dd3ed660",
   "metadata": {
    "id": "0cd0c0d1-da8a-4de6-9efc-86a7dd3ed660"
   },
   "outputs": [],
   "source": [
    "# Let's see the actual template of this prompt formatter\n",
    "model.prompt.TEMPLATE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72956a2f-f051-42d2-9e08-47e954d88e5c",
   "metadata": {
    "id": "72956a2f-f051-42d2-9e08-47e954d88e5c"
   },
   "source": [
    "---\n",
    "\n",
    "We see that the template contains two turns - `user` and `assistant`.\n",
    "\n",
    "User template looks as follows: `<|startoftranscript|>|source_lang||task||target_lang||pnc|`\n",
    "During execution, we remove the `|` in order to fill in the actual value of the slots provided by the the data loader.\n",
    "\n",
    "User holds the following allowed slots -\n",
    "* `source_lang`\n",
    "* `target_lang`\n",
    "* `task`\n",
    "* `pnc`\n",
    "\n",
    "Similarly, for Assistant template : `|text|<|endoftext|>`\n",
    "\n",
    "Assistant holds the following allowed slots -\n",
    "* `text`"
   ]
  },
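  {
   "cell_type": "markdown",
   "id": "c7a1e3d2-6b4f-4e0a-8d2c-1a9b5f3e7c04",
   "metadata": {
    "id": "c7a1e3d2-6b4f-4e0a-8d2c-1a9b5f3e7c04"
   },
   "source": [
    "---\n",
    "\n",
    "The slot mechanism can be illustrated with plain string processing. The sketch below is a simplified stand-in for what a prompt formatter does and is not NeMo's implementation. It treats `|name|` as a slot unless it is part of a special token like `<|endoftext|>`. Using that rule, it lists the slots of each template and renders a two-turn (user, assistant) dialog by filling them in. The slot values (`<|en|>`, `<|transcribe|>`, `<|pnc|>`) are hypothetical examples; in practice, the prompt format function derives them from the manifest.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "TEMPLATE = {\n",
    "    'user': '<|startoftranscript|>|source_lang||task||target_lang||pnc|',\n",
    "    'assistant': '|text|<|endoftext|>',\n",
    "}\n",
    "\n",
    "# A slot is |name| that is not wrapped in <...> like the special tokens.\n",
    "SLOT = re.compile(r'(?<!<)[|]([a-z_]+)[|](?!>)')\n",
    "\n",
    "def find_slots(template: str) -> list[str]:\n",
    "    return SLOT.findall(template)\n",
    "\n",
    "def render_turn(role: str, slots: dict[str, str]) -> str:\n",
    "    # Replace every |slot| in the role's template; a missing value raises KeyError.\n",
    "    return SLOT.sub(lambda m: slots[m.group(1)], TEMPLATE[role])\n",
    "\n",
    "def render_dialog(turns: list[dict]) -> str:\n",
    "    # Concatenate the rendered turns in order, e.g. user then assistant.\n",
    "    return ''.join(render_turn(t['role'], t['slots']) for t in turns)\n",
    "\n",
    "for role in TEMPLATE:\n",
    "    print(role, find_slots(TEMPLATE[role]))\n",
    "\n",
    "turns = [\n",
    "    {'role': 'user', 'slots': {'source_lang': '<|en|>', 'task': '<|transcribe|>',\n",
    "                               'target_lang': '<|en|>', 'pnc': '<|pnc|>'}},\n",
    "    {'role': 'assistant', 'slots': {'text': 'I said something'}},\n",
    "]\n",
    "print(render_dialog(turns))\n",
    "# <|startoftranscript|><|en|><|transcribe|><|en|><|pnc|>I said something<|endoftext|>\n",
    "```"
   ]
  },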
  {
   "cell_type": "markdown",
   "id": "540c04af-34d1-4b46-b935-40b16f54ca03",
   "metadata": {
    "id": "540c04af-34d1-4b46-b935-40b16f54ca03"
   },
   "source": [
    "### Creating and Using a Custom Prompt Formatter\n",
    "\n",
    "While we provide a pre-trained model with a pre-defined prompt format, we also enable users to create their own PromptFormatter subclass and change it as needed.\n",
    "\n",
    "Below, we show a simple modification to the model's PromptFormatter and show how to change it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0adb576c-df58-4b66-b8fa-8e653da6fead",
   "metadata": {
    "id": "0adb576c-df58-4b66-b8fa-8e653da6fead"
   },
   "outputs": [],
   "source": [
    "# Create a new prompt formatter using the original CanaryPromptFormatter class as baseclass\n",
    "class CanaryPromptFormatterV2(model.prompt.__class__):\n",
    "\n",
    "    # make sure to provide a new name\n",
    "    NAME: str = \"canary_custom\"\n",
    "\n",
    "    # Make any changes as necessary.\n",
    "    # For this demonstration, we will not change anything other than the name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7d85683-ddd0-40c5-956d-e14d09243424",
   "metadata": {
    "id": "f7d85683-ddd0-40c5-956d-e14d09243424"
   },
   "outputs": [],
   "source": [
    "# Next, lets update the model's prompt formatter\n",
    "model.change_prompt(\"canary_custom\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6581f934-a55b-41df-864a-351d1fb0029e",
   "metadata": {
    "id": "6581f934-a55b-41df-864a-351d1fb0029e"
   },
   "source": [
    "---\n",
    "\n",
    "We have now successfully changed the prompt format to `canary_custom`.\n",
    "\n",
    "**Note**: It is important to know that when changing the prompt format, the name of the new prompt format class (`canary_custom` in this case) **has to match** the name of the prompt function registered with `@registered_prompt_format_fn`!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1d84948-8f73-4c31-923f-eaf01d877835",
   "metadata": {
    "id": "c1d84948-8f73-4c31-923f-eaf01d877835",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Check if everything is ok -\n",
    "model.prompt.__class__.__name__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f617cda0-d16b-400a-b495-dac213d318e1",
   "metadata": {
    "id": "f617cda0-d16b-400a-b495-dac213d318e1"
   },
   "outputs": [],
   "source": [
    "model.prompt_format"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb964964-e978-43e9-befa-9bb0904db82f",
   "metadata": {
    "id": "cb964964-e978-43e9-befa-9bb0904db82f"
   },
   "source": [
    "---\n",
    "For the rest of the tutorial, we will revert back to the original prompt formatter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "526093a8-86ba-48f0-a60b-55642720fc4e",
   "metadata": {
    "id": "526093a8-86ba-48f0-a60b-55642720fc4e"
   },
   "outputs": [],
   "source": [
    "model.change_prompt('canary')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c4d2986-89b4-4589-ab0e-69683084cfd4",
   "metadata": {
    "id": "9c4d2986-89b4-4589-ab0e-69683084cfd4"
   },
   "source": [
    "## Creating / Using a Multi Task Dataset\n",
    "\n",
    "Now that we have learned how to modify the model's prompt formatter and the underlying format function that maps manifest items into slots to inject into the prompt template, next let's take a look at how to use and create custom datasets for training multi task models.\n",
    "\n",
    "---\n",
    "\n",
    "Unlike previous tutorials that showcase how to use pre-defined datasets and point them to your manifest files, we will take a slightly more hands-on approach for multi task modes. This is due to shear flexibility of multi task models - they can do almost any task that you can formulate into a \"speech in - text out\" problem.\n",
    "\n",
    "So it is not easy to have a pre-defined dataset class that can handle all new ideas and tasks that researchers can come up with.\n",
    "\n",
    "Instead, we showcase how to build a custom dataset for yourself and use it with the Multi Task model instead."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b35ca0c2-8ceb-423f-b9ef-7dd6ec5a6952",
   "metadata": {
    "id": "b35ca0c2-8ceb-423f-b9ef-7dd6ec5a6952"
   },
   "source": [
    "---\n",
    "\n",
    "However, we also provide a base class that can be used as is by users if they dont want the hassle of writing their own datasets.\n",
    "\n",
    "This is handled by the `PromptedAudioToTextLhotseDataset` -  it maps user defined manifest items to the items defined in the prompt template of the model, so as long as the manifest corresponds to the slots supported by the model, it will be managed by the Dataset automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d35d513-8538-4bcb-b892-898f16ad3f0f",
   "metadata": {
    "id": "3d35d513-8538-4bcb-b892-898f16ad3f0f",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from nemo.collections.asr.data.audio_to_text_lhotse_prompted import PromptedAudioToTextLhotseDataset\n",
    "\n",
    "# Uncomment below line to see the class definition of PromptedAudioToTextLhotseDataset\n",
    "# PromptedAudioToTextLhotseDataset??"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51e3a150-40b9-4599-8c6e-0f01698989b4",
   "metadata": {
    "id": "51e3a150-40b9-4599-8c6e-0f01698989b4"
   },
   "source": [
    "### Creating a New Prompted Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56208452-ea18-44c8-8c71-0daef431dc31",
   "metadata": {
    "id": "56208452-ea18-44c8-8c71-0daef431dc31"
   },
   "outputs": [],
   "source": [
    "import torch.utils.data\n",
    "from lhotse import CutSet\n",
    "from lhotse.cut import MixedCut, MonoCut\n",
    "from lhotse.dataset import AudioSamples\n",
    "from lhotse.dataset.collation import collate_vectors\n",
    "\n",
    "from nemo.collections.asr.data.audio_to_text_lhotse_prompted import PromptedAudioToTextLhotseDataset, PromptedAudioToTextMiniBatch\n",
    "\n",
    "class MyCanaryPromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"\n",
    "    This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`.\n",
    "    It is a Lhotse-style dataset that converts a mini-batch of Cuts into tensors.\n",
    "    The main difference from ``LhotseSpeechToTextBpeDataset`` is that we introduce\n",
    "    a special prompt format for multitask encoder-decoder models.\n",
    "\n",
    "    To perform the prompt formatting, we accept a ``prompt_format_fn``.\n",
    "    It's expected to accept:\n",
    "    * a ``Cut`` a single MonoCut or MixedCut\n",
    "    * a ``PromptFormatter`` Prepend and append control tokens to the token sequence\n",
    "\n",
    "    Tokenized utterances will be extended with special prompt tokens according to ``prompt_format_fn`` logic.\n",
    "    We support cuts with multiple supervision segments -- their tokenized texts will be concatenated before we add the prompt tokens.\n",
    "    This is useful, for example, in code-switched scenarios where each segment is spoken in a different language.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        tokenizer: 'TokenizerSpec',\n",
    "        prompt: PromptFormatter\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.tokenizer = tokenizer\n",
    "        self.load_audio = AudioSamples(fault_tolerant=True)\n",
    "        self.padding_value = self.tokenizer.pad_id\n",
    "        self.prompt = prompt\n",
    "        self.prompt_format_fn = get_prompt_format_fn(Cut, self.prompt)  # Use the default canary prompt function\n",
    "\n",
    "\n",
    "    def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:\n",
    "        audio, audio_lens, cuts = self.load_audio(cuts)\n",
    "        answers = []\n",
    "        prompts = []\n",
    "        prompts_with_answers = []\n",
    "\n",
    "        for cut in cuts:\n",
    "            prompted_answers = self.prompt_format_fn(cut, self.prompt)\n",
    "            answers.append(prompted_answers[\"answer_ids\"])\n",
    "            prompts.append(prompted_answers[\"context_ids\"])\n",
    "            prompts_with_answers.append(prompted_answers[\"input_ids\"])\n",
    "        \n",
    "        transcript, transcript_lens = self._collate_tokens(answers)\n",
    "        prompts_with_answers, prompts_with_answers_lens = self._collate_tokens(prompts_with_answers)\n",
    "        prompts, prompt_lens = self._collate_tokens(prompts)\n",
    "\n",
    "        return PromptedAudioToTextMiniBatch(\n",
    "            audio=audio,\n",
    "            audio_lens=audio_lens,\n",
    "            transcript=transcript,\n",
    "            transcript_lens=transcript_lens,\n",
    "            prompt=prompts,\n",
    "            prompt_lens=prompt_lens,\n",
    "            prompted_transcript=prompts_with_answers,\n",
    "            prompted_transcript_lens=prompts_with_answers_lens,\n",
    "            cuts=cuts.drop_in_memory_data(),\n",
    "        )\n",
    "\n",
    "    def _collate_tokens(self, tokens: list[list[int] | torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:\n",
    "        tokens = [torch.as_tensor(t) for t in tokens]\n",
    "        token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)\n",
    "        tokens = collate_vectors(tokens, padding_value=self.padding_value)\n",
    "        return tokens, token_lens\n"
   ]
  },
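  {
   "cell_type": "markdown",
   "id": "c7a1e3d2-6b4f-4e0a-8d2c-1a9b5f3e7c05",
   "metadata": {
    "id": "c7a1e3d2-6b4f-4e0a-8d2c-1a9b5f3e7c05"
   },
   "source": [
    "---\n",
    "\n",
    "The `_collate_tokens` helper above turns a list of variable-length token sequences into one padded batch plus a tensor of lengths. The pure-Python sketch below shows the same idea without PyTorch or Lhotse. It assumes right-padding with the tokenizer's pad id. `collate_tokens` is a hypothetical name, and the pad id `0` in the example is arbitrary.\n",
    "\n",
    "```python\n",
    "def collate_tokens(tokens: list[list[int]], padding_value: int) -> tuple[list[list[int]], list[int]]:\n",
    "    # Record each sequence's true length before padding.\n",
    "    lens = [len(t) for t in tokens]\n",
    "    max_len = max(lens, default=0)\n",
    "    # Right-pad every sequence to the length of the longest one in the batch.\n",
    "    padded = [list(t) + [padding_value] * (max_len - len(t)) for t in tokens]\n",
    "    return padded, lens\n",
    "\n",
    "padded, lens = collate_tokens([[5, 6, 7], [8], [9, 10]], padding_value=0)\n",
    "print(padded)  # [[5, 6, 7], [8, 0, 0], [9, 10, 0]]\n",
    "print(lens)    # [3, 1, 2]\n",
    "```\n",
    "\n",
    "The lengths are returned alongside the padded batch so that downstream code can tell real tokens apart from padding."
   ]
  },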
  {
   "cell_type": "markdown",
   "id": "5cb71ba1-ce2e-49c7-8126-be7e7851c812",
   "metadata": {
    "id": "5cb71ba1-ce2e-49c7-8126-be7e7851c812"
   },
   "source": [
    "---\n",
    "\n",
    "The above class is mostly a demonstration, but it showcases how users might flexibly change the prompt formatter, prompt format function and even the data set that handles these two in a flexible way.\n",
    "\n",
    "The order of operations is usually this -\n",
    "\n",
    "1) Create a new Prompt Formatter class - this denotes the slots that each turn can have (including new task inputs or other values). This class is auto registered.\n",
    "2) Create a new Prompt Format function - Using `@registered_prompt_format_fn` decorator, write a custom function that accepts args and processes the provided input data from a manifest.\n",
    "3) Create a new Dataset class (usually based on the `PromptedAudioToTextLhotseDataset` dataset) that uses the Prompt Format function to convert manifest items into nicely formatted samples that can be passed to the Prompt Formatter."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7bf8078-663e-43cb-b045-0c8b6ef08e30",
   "metadata": {
    "id": "a7bf8078-663e-43cb-b045-0c8b6ef08e30"
   },
   "source": [
    "# Preparing a Canary Dataset\n",
    "\n",
    "Now that we have all the pieces together on the model side, let's take a look on the data side."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83c9eabc-0473-463e-be1f-ab6d5f519a79",
   "metadata": {
    "id": "83c9eabc-0473-463e-be1f-ab6d5f519a79"
   },
   "source": [
    "## Required Roles Defined by Prompt Format\n",
    "\n",
    "These are the available 'roles' available in the prompt format - they denote at each turn, one role can be enabled and its input or output can be calculated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11ff9641-53fd-4481-b414-0edc12bf4dc3",
   "metadata": {
    "id": "11ff9641-53fd-4481-b414-0edc12bf4dc3"
   },
   "outputs": [],
   "source": [
    "model.prompt.get_roles()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "203a67e2-74fd-440c-9658-451f41239f36",
   "metadata": {
    "id": "203a67e2-74fd-440c-9658-451f41239f36"
   },
   "outputs": [],
   "source": [
    "for role in model.prompt.get_roles():\n",
    "    print(role, model.prompt.get_slots(role))\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e887f9d-94e7-4843-9da8-f914e24651f3",
   "metadata": {
    "id": "8e887f9d-94e7-4843-9da8-f914e24651f3"
   },
   "source": [
    "## Create a Data Module\n",
    "\n",
    "Data Modules are one way of organizing datasets in PyTorch Lightning. It provides a unified place where data loading and processing can be potentially handled.\n",
    "\n",
    "**Note**: This isn't strictly necessary - you can achieve the same using just Pytorch dataloaders directly and passing it to Trainer.fit() but we showcase a data module codebase that can be extended by the user."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51d58931-4166-4ab9-a755-4c5268001192",
   "metadata": {
    "id": "51d58931-4166-4ab9-a755-4c5268001192"
   },
   "source": [
    "----\n",
    "\n",
    "Our `CanaryAN4DataModule` prepares data for two tasks. The first is English ASR: transcribing the English AN4 dataset. The second is English-to-German AST: translating the English audio directly into German text.\n",
    "\n",
    "For simplicity's sake, we create the German targets by translating the English transcripts with a small off-the-shelf model (`t5-small`)."
   ]
  },
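  {
   "cell_type": "markdown",
   "id": "7a2d9e41-5c3b-4f80-b6e2-91c4d0a8f312",
   "metadata": {
    "id": "7a2d9e41-5c3b-4f80-b6e2-91c4d0a8f312"
   },
   "source": [
    "The two manifests are tied together entry by entry. Each AST entry is a copy of an ASR entry in which only `text`, `task`, and `target_lang` change, which is what `prepare_data()` in the data module below does. The `to_ast_entry` helper and the toy dictionary translator are hypothetical stand-ins for illustration only. The data module itself uses `t5-small` for translation, and the file path and duration in the example are made-up values.\n",
    "\n",
    "```python\n",
    "import copy\n",
    "\n",
    "\n",
    "def to_ast_entry(asr_entry: dict, translate, target_lang: str = \"de\") -> dict:\n",
    "    \"\"\"Derive an English-to-X AST entry from an English ASR entry (hypothetical helper).\"\"\"\n",
    "    ast_entry = copy.deepcopy(asr_entry)  # keeps audio_filepath, duration, pnc, source_lang\n",
    "    ast_entry[\"text\"] = translate(asr_entry[\"text\"])\n",
    "    ast_entry[\"task\"] = \"ast\"\n",
    "    ast_entry[\"target_lang\"] = target_lang\n",
    "    return ast_entry\n",
    "\n",
    "\n",
    "def toy_translate(text: str) -> str:\n",
    "    \"\"\"Stand-in for a real MT model: a tiny lookup table.\"\"\"\n",
    "    return {\"yes\": \"ja\", \"no\": \"nein\"}[text]\n",
    "\n",
    "\n",
    "asr_entry = {\n",
    "    \"audio_filepath\": \"an4/wav/an4_clstk/fash/cen4-fash-b.wav\",  # made-up example path\n",
    "    \"duration\": 1.5,  # made-up example duration in seconds\n",
    "    \"text\": \"yes\",\n",
    "    \"pnc\": \"no\",\n",
    "    \"source_lang\": \"en\",\n",
    "    \"target_lang\": \"en\",\n",
    "    \"task\": \"asr\",\n",
    "}\n",
    "ast_entry = to_ast_entry(asr_entry, toy_translate)\n",
    "print(ast_entry[\"text\"], ast_entry[\"task\"], ast_entry[\"target_lang\"])  # ja ast de\n",
    "print(asr_entry[\"task\"])  # asr (the original entry is left untouched)\n",
    "```"
   ]
  },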
  {
   "cell_type": "markdown",
   "id": "91ed74ca-5d5e-412d-a813-0659014aa9a3",
   "metadata": {
    "id": "91ed74ca-5d5e-412d-a813-0659014aa9a3"
   },
   "source": [
    "---\n",
    "\n",
    "In NeMo 2.0, we use [Lhotse](https://github.com/lhotse-speech/lhotse) as the data backbone for speech tasks, which simplifies working with custom speech datasets.\n",
    "\n",
    "Most of the work is done by the following call:\n",
    "\n",
    "```python\n",
    "from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config\n",
    "\n",
    "get_lhotse_dataloader_from_config(\n",
    "    OmegaConf.create(config),  # Pass in a config that points to the manifest files and other arguments\n",
    "    global_rank=self.trainer.global_rank,\n",
    "    world_size=self.trainer.world_size,\n",
    "    # Pass in the dataset class for Lhotse to handle. This class receives a Lhotse CutSet as input.\n",
    "    dataset=MyCanaryPromptedAudioToTextLhotseDataset(tokenizer=self.tokenizer, prompt=CanaryPromptFormatter(self.tokenizer)),\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a15ab9b-7603-4ac5-890c-92a541a0527c",
   "metadata": {
    "id": "4a15ab9b-7603-4ac5-890c-92a541a0527c"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import json\n",
    "import copy\n",
    "import subprocess\n",
    "import tarfile\n",
    "import wget\n",
    "import librosa\n",
    "import tqdm\n",
    "import torch\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "import lightning.pytorch as L\n",
    "\n",
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest\n",
    "from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config\n",
    "\n",
    "\n",
    "# Function to build a manifest\n",
    "def build_manifest(transcripts_path, manifest_path, wav_path, data_dir):\n",
    "    with open(transcripts_path, 'r') as fin:\n",
    "        with open(manifest_path, 'w') as fout:\n",
    "            for line in fin:\n",
    "                # Lines look like this:\n",
    "                # <s> transcript </s> (fileID)\n",
    "                transcript = line[: line.find('(')-1].lower()\n",
    "                transcript = transcript.replace('<s>', '').replace('</s>', '')\n",
    "                transcript = transcript.strip()\n",
    "\n",
    "                file_id = line[line.find('(')+1 : -2]  # e.g. \"cen4-fash-b\"\n",
    "                audio_path = os.path.join(\n",
    "                    data_dir, wav_path,\n",
    "                    file_id[file_id.find('-')+1 : file_id.rfind('-')],\n",
    "                    file_id + '.wav')\n",
    "\n",
    "                duration = librosa.core.get_duration(path=audio_path)\n",
    "\n",
    "                # Write the metadata to the manifest\n",
    "                metadata = {\n",
    "                    \"audio_filepath\": audio_path,\n",
    "                    \"duration\": duration,\n",
    "                    \"text\": transcript,\n",
    "                    \"pnc\": \"no\",\n",
    "                    \"source_lang\": \"en\",\n",
    "                    \"target_lang\": \"en\",\n",
    "                    \"task\": \"asr\",\n",
    "                }\n",
    "                json.dump(metadata, fout)\n",
    "                fout.write('\\n')\n",
    "\n",
    "    return manifest_path\n",
    "\n",
    "\n",
    "class CanaryAN4DataModule(L.LightningDataModule):\n",
    "\n",
    "    def __init__(self, tokenizer, data_dir: str = \"./an4/\", batch_size=8):\n",
    "        super().__init__()\n",
    "        self.tokenizer = tokenizer\n",
    "        self.data_dir = data_dir\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "        # ASR manifests\n",
    "        self.train_manifest = data_dir + '/an4/train_manifest.json'\n",
    "        self.test_manifest = data_dir + '/an4/test_manifest.json'\n",
    "\n",
    "        # AST manifests\n",
    "        self.ast_train_manifest = data_dir + '/an4/ast_train_manifest.json'\n",
    "        self.ast_test_manifest = data_dir + '/an4/ast_test_manifest.json'\n",
    "\n",
    "        # Combined manifests\n",
    "        self.combined_train_manifest = data_dir + '/an4/combined_train_manifest.json'\n",
    "        self.combined_test_manifest = data_dir + '/an4/combined_test_manifest.json'\n",
    "\n",
    "    def setup(self, stage):\n",
    "        # make assignments here (val/train/test split)\n",
    "        # called on every process in DDP\n",
    "        # Assign train/val datasets for use in dataloaders\n",
    "        pass\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        config = {'manifest_filepath': self.combined_train_manifest, 'batch_size': self.batch_size,\n",
    "                  'num_workers': 4, 'shuffle': True, 'min_duration': 0.3, 'max_duration': 10.0}\n",
    "        return self._setup_dataloader(config)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        config = {'manifest_filepath': self.combined_test_manifest, 'batch_size': self.batch_size,\n",
    "                  'num_workers': 4, 'shuffle': False, 'min_duration': 0.3, 'max_duration': 10.0}\n",
    "        return self._setup_dataloader(config)\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        config = {'manifest_filepath': self.combined_test_manifest, 'batch_size': self.batch_size,\n",
    "                  'num_workers': 4, 'shuffle': False, 'min_duration': 0.3, 'max_duration': 10.0}\n",
    "        return self._setup_dataloader(config)\n",
    "\n",
    "    def teardown(self, stage):\n",
    "        # clean up after fit or test\n",
    "        # called on every process in DDP\n",
    "        pass\n",
    "\n",
    "    def _setup_dataloader(self, config):\n",
    "        \"\"\"\n",
    "        The main function that creates the data loader using Lhotse's integration with NeMo.\n",
    "        \"\"\"\n",
    "        return get_lhotse_dataloader_from_config(\n",
    "                OmegaConf.create(config),\n",
    "                global_rank=self.trainer.global_rank,\n",
    "                world_size=self.trainer.world_size,\n",
    "                # Note the passing of our custom dataset\n",
    "                dataset=MyCanaryPromptedAudioToTextLhotseDataset(tokenizer=self.tokenizer, prompt=CanaryPromptFormatter(self.tokenizer)),\n",
    "            )\n",
    "\n",
    "    def prepare_data(self):\n",
    "        # download, split, etc...\n",
    "        # only called on 1 GPU/TPU in distributed\n",
    "        if not os.path.exists(self.data_dir):\n",
    "            os.makedirs(self.data_dir)\n",
    "\n",
    "        data_dir = self.data_dir\n",
    "        if not os.path.exists(data_dir + '/an4_sphere.tar.gz'):\n",
    "            an4_url = 'https://dldata-public.s3.us-east-2.amazonaws.com/an4_sphere.tar.gz'\n",
    "            an4_path = wget.download(an4_url, data_dir)\n",
    "            print(f\"Dataset downloaded at: {an4_path}\")\n",
    "        else:\n",
    "            print(\"Tarfile already exists.\")\n",
    "            an4_path = data_dir + '/an4_sphere.tar.gz'\n",
    "\n",
    "        if not os.path.exists(data_dir + '/an4/'):\n",
    "            # Untar and convert .sph to .wav (using sox)\n",
    "            tar = tarfile.open(an4_path)\n",
    "            tar.extractall(path=data_dir)\n",
    "\n",
    "            print(\"Converting .sph to .wav...\")\n",
    "            sph_list = glob.glob(data_dir + '/an4/**/*.sph', recursive=True)\n",
    "            for sph_path in sph_list:\n",
    "                wav_path = sph_path[:-4] + '.wav'\n",
    "                cmd = [\"sox\", sph_path, wav_path]\n",
    "                subprocess.run(cmd)\n",
    "        print(\"Finished conversion.\\n******\")\n",
    "\n",
    "        # Building Manifests\n",
    "        print(\"******\")\n",
    "        train_transcripts = data_dir + '/an4/etc/an4_train.transcription'\n",
    "        train_manifest = self.train_manifest\n",
    "        if not os.path.isfile(train_manifest):\n",
    "            build_manifest(train_transcripts, train_manifest, 'an4/wav/an4_clstk', data_dir)\n",
    "            print(\"Training manifest created.\")\n",
    "\n",
    "        test_transcripts = data_dir + '/an4/etc/an4_test.transcription'\n",
    "        test_manifest = self.test_manifest\n",
    "        if not os.path.isfile(test_manifest):\n",
    "            build_manifest(test_transcripts, test_manifest, 'an4/wav/an4test_clstk', data_dir)\n",
    "            print(\"Test manifest created.\")\n",
    "        print(\"*** Wrote manifests for Eng ***\")\n",
    "\n",
    "        train_manifest_data = read_manifest(self.train_manifest)\n",
    "        test_manifest_data = read_manifest(self.test_manifest)\n",
    "\n",
    "        if not os.path.isfile(self.ast_train_manifest) or not os.path.isfile(self.ast_test_manifest) or not os.path.isfile(self.combined_train_manifest) or not os.path.isfile(self.combined_test_manifest):\n",
    "            tokenizer = T5Tokenizer.from_pretrained(\"google-t5/t5-small\")\n",
    "            t5_model = T5ForConditionalGeneration.from_pretrained(\"google-t5/t5-small\")\n",
    "\n",
    "            if torch.cuda.is_available():\n",
    "                t5_model = t5_model.cuda()\n",
    "\n",
    "            def pipe(text):\n",
    "                if isinstance(text, str):\n",
    "                    text = [text]\n",
    "\n",
    "                prefix = \"translate English to German\"\n",
    "                prompts = [prefix + \": \" + x for x in text]\n",
    "                input_ids = tokenizer(prompts, return_tensors=\"pt\", padding=True, truncation=True).input_ids\n",
    "                input_ids = input_ids.to(t5_model.device)\n",
    "                outputs = t5_model.generate(input_ids, max_new_tokens=64)\n",
    "                return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]\n",
    "\n",
    "            ast_train_manifest_data = copy.deepcopy(train_manifest_data)\n",
    "            ast_test_manifest_data = copy.deepcopy(test_manifest_data)\n",
    "\n",
    "            print(\"Translating train set\")\n",
    "            train_texts = [x['text'] for x in train_manifest_data]\n",
    "            BATCH_SIZE = 32\n",
    "\n",
    "            for i in tqdm.tqdm(range(0, len(train_texts), BATCH_SIZE)):\n",
    "                batch_texts = train_texts[i:i+BATCH_SIZE]\n",
    "                batch_texts = pipe(batch_texts)\n",
    "                for j, text in enumerate(batch_texts):\n",
    "                    ast_train_manifest_data[i+j]['text'] = text\n",
    "                    ast_train_manifest_data[i+j]['task'] = 'ast'\n",
    "                    ast_train_manifest_data[i+j]['target_lang'] = 'de'\n",
    "\n",
    "            print(\"Translating test set\")\n",
    "            for data in tqdm.tqdm(ast_test_manifest_data, total=len(ast_test_manifest_data)):\n",
    "                data['text'] = pipe(data['text'])[0]\n",
    "                data['task'] = 'ast'\n",
    "                data['target_lang'] = 'de'\n",
    "\n",
    "            write_manifest(self.ast_train_manifest, ast_train_manifest_data)\n",
    "            write_manifest(self.ast_test_manifest, ast_test_manifest_data)\n",
    "\n",
    "            print(\"*** Wrote ast manifests ***\")\n",
    "\n",
    "            combined_train, combined_test = [], []\n",
    "            combined_train.extend(train_manifest_data)\n",
    "            combined_train.extend(ast_train_manifest_data)\n",
    "\n",
    "            combined_test.extend(test_manifest_data)\n",
    "            combined_test.extend(ast_test_manifest_data)\n",
    "\n",
    "            write_manifest(self.combined_train_manifest, combined_train)\n",
    "            write_manifest(self.combined_test_manifest, combined_test)\n",
    "            print(\"*** Wrote combined manifests ***\")\n",
    "\n",
    "        else:\n",
    "            print(\"*** AST and combined manifests already exist, skipping ***\")\n"
   ]
  },
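  {
   "cell_type": "markdown",
   "id": "c4e8b2f0-19d6-4a3e-8f57-0b6a3d9e7c25",
   "metadata": {
    "id": "c4e8b2f0-19d6-4a3e-8f57-0b6a3d9e7c25"
   },
   "source": [
    "The string slicing in `build_manifest` packs several parsing steps into a few index expressions. As an illustration only, the sketch below does the same parsing with a regular expression instead. This is a swapped-in technique, not the code above. It assumes every transcription line has the form `<s> transcript </s> (fileID)` noted in the comment there, and `parse_transcript_line` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import re\n",
    "\n",
    "# Assumed line format: \"<s> transcript </s> (fileID)\"\n",
    "LINE_RE = re.compile(r\"^<s>\\s*(?P<text>.*?)\\s*</s>\\s*\\((?P<file_id>[^)]+)\\)\\s*$\")\n",
    "\n",
    "\n",
    "def parse_transcript_line(line: str, wav_root: str) -> tuple[str, str]:\n",
    "    \"\"\"Return (lower-cased transcript, audio path) for one transcription line.\"\"\"\n",
    "    match = LINE_RE.match(line)\n",
    "    if match is None:\n",
    "        raise ValueError(f\"unexpected transcript line: {line!r}\")\n",
    "    text = match.group(\"text\").lower()\n",
    "    file_id = match.group(\"file_id\")  # e.g. \"cen4-fash-b\"\n",
    "    # The sub-directory is the middle part of the file ID, e.g. \"fash\".\n",
    "    sub_dir = file_id[file_id.find(\"-\") + 1 : file_id.rfind(\"-\")]\n",
    "    return text, os.path.join(wav_root, sub_dir, file_id + \".wav\")\n",
    "\n",
    "\n",
    "text, path = parse_transcript_line(\"<s> YES </s> (cen4-fash-b)\\n\", \"an4/wav/an4_clstk\")\n",
    "print(text)  # yes\n",
    "print(path)  # an4/wav/an4_clstk/fash/cen4-fash-b.wav on POSIX systems\n",
    "```"
   ]
  },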
  {
   "cell_type": "markdown",
   "id": "e06e697d-7dc2-489f-a52f-195946bfbf6e",
   "metadata": {
    "id": "e06e697d-7dc2-489f-a52f-195946bfbf6e"
   },
   "source": [
    "---\n",
    "\n",
    "Each entry in the prepared manifests has the following keys by default.\n",
    "\n",
    "As you may recognize, these are the same keys used in the `slots` of the `CanaryPromptFormatter` class, so each of these values is mapped back to the slot of the same name.\n",
    "\n",
    "```python\n",
    "metadata = {\n",
    "    \"audio_filepath\": audio_path,\n",
    "    \"duration\": duration,\n",
    "    \"text\": transcript,\n",
    "    \"pnc\": \"no\",\n",
    "    \"source_lang\": \"en\",\n",
    "    \"target_lang\": \"en\",\n",
    "    \"task\": \"asr\",\n",
    "}\n",
    "```\n",
    "\n",
    "The most important function in the Data Module above is `prepare_data()`:\n",
    "\n",
    "1) It first downloads the AN4 audio files and converts them to wav files.\n",
    "2) It then writes a manifest file with the above keys for the ASR task.\n",
    "3) It translates the English transcripts with a `t5-small` model to generate German transcripts.\n",
    "4) It writes another manifest for the AST task with these translated texts.\n",
    "5) Finally, it builds combined manifests for multi-task training on both ASR (en) and AST (en to de).\n",
    "\n",
    "**Note**: We use `prepare_data()` here only for demonstration. Normally, users would process their data before experimentation, so they would only need to implement the methods above `prepare_data()` in their Data Module."
   ]
  },
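  {
   "cell_type": "markdown",
   "id": "e91f3b7a-6d20-4c8b-a4f1-72c5e0d3b948",
   "metadata": {
    "id": "e91f3b7a-6d20-4c8b-a4f1-72c5e0d3b948"
   },
   "source": [
    "Because a manifest is a JSON Lines file (one JSON object per line, as written by `build_manifest`), it is easy to sanity-check before training. The hypothetical `check_manifest` helper below is not part of NeMo. It assumes all seven keys listed above are required and counts entries per `(task, target_lang)` pair, so you can confirm that a combined manifest really mixes ASR and AST samples.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# Assumption: every entry must carry all keys listed above.\n",
    "REQUIRED_KEYS = {\"audio_filepath\", \"duration\", \"text\", \"pnc\", \"source_lang\", \"target_lang\", \"task\"}\n",
    "\n",
    "\n",
    "def check_manifest(lines) -> dict:\n",
    "    \"\"\"Validate JSON Lines manifest entries and count them per (task, target_lang) pair.\"\"\"\n",
    "    counts = {}\n",
    "    for lineno, line in enumerate(lines, start=1):\n",
    "        if not line.strip():\n",
    "            continue  # tolerate blank lines\n",
    "        entry = json.loads(line)\n",
    "        missing = REQUIRED_KEYS.difference(entry)\n",
    "        if missing:\n",
    "            raise ValueError(f\"line {lineno}: missing keys {sorted(missing)}\")\n",
    "        key = (entry[\"task\"], entry[\"target_lang\"])\n",
    "        counts[key] = counts.get(key, 0) + 1\n",
    "    return counts\n",
    "\n",
    "\n",
    "sample = [\n",
    "    json.dumps({\"audio_filepath\": \"a.wav\", \"duration\": 1.0, \"text\": \"yes\", \"pnc\": \"no\",\n",
    "                \"source_lang\": \"en\", \"target_lang\": \"en\", \"task\": \"asr\"}),\n",
    "    json.dumps({\"audio_filepath\": \"a.wav\", \"duration\": 1.0, \"text\": \"ja\", \"pnc\": \"no\",\n",
    "                \"source_lang\": \"en\", \"target_lang\": \"de\", \"task\": \"ast\"}),\n",
    "]\n",
    "print(check_manifest(sample))  # {('asr', 'en'): 1, ('ast', 'de'): 1}\n",
    "```\n",
    "\n",
    "Since a Python file object iterates over its lines, the same helper could also be given an open file handle for `data_module.combined_train_manifest` once `prepare_data()` has run."
   ]
  },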
  {
   "cell_type": "markdown",
   "id": "739f0141-1e0e-4db7-b1f6-9d13589bf50c",
   "metadata": {
    "id": "739f0141-1e0e-4db7-b1f6-9d13589bf50c"
   },
   "source": [
    "## Download and Prepare Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "323287f1-9a44-49ab-8438-dcbf34bf2ebe",
   "metadata": {
    "id": "323287f1-9a44-49ab-8438-dcbf34bf2ebe"
   },
   "outputs": [],
   "source": [
    "data_module = CanaryAN4DataModule(tokenizer=model.tokenizer, batch_size=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "123faf0d-05b2-4f12-850f-350a175ba7c1",
   "metadata": {
    "id": "123faf0d-05b2-4f12-850f-350a175ba7c1",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_module.prepare_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbec085b-9600-49bd-8739-73e5e8e3773f",
   "metadata": {
    "id": "fbec085b-9600-49bd-8739-73e5e8e3773f"
   },
   "outputs": [],
   "source": [
    "!head -n 5 {data_module.train_manifest}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66bad9ac-3bad-4d84-8b30-830856c06804",
   "metadata": {
    "id": "66bad9ac-3bad-4d84-8b30-830856c06804"
   },
   "outputs": [],
   "source": [
    "!head -n 5 {data_module.ast_train_manifest}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cde19c46-e78c-4d7c-adbf-f1559c9203e1",
   "metadata": {
    "id": "cde19c46-e78c-4d7c-adbf-f1559c9203e1"
   },
   "source": [
    "# Evaluate Model before Training\n",
    "\n",
    "The Canary multi-task model is already very capable and achieves strong scores on multiple benchmarks, so we first evaluate baseline numbers on the two tasks:\n",
    "\n",
    "1) ASR: WER on the transcripts\n",
    "\n",
    "2) AST: BLEU (computed with SacreBLEU) on the translations"
   ]
  },
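  {
   "cell_type": "markdown",
   "id": "1b5d7c93-e2a4-4f6b-9d08-3c7e1a5f2b60",
   "metadata": {
    "id": "1b5d7c93-e2a4-4f6b-9d08-3c7e1a5f2b60"
   },
   "source": [
    "As a reminder of what the first metric measures: WER is the word-level edit distance (substitutions + deletions + insertions) between hypothesis and reference, summed over the test set and divided by the total number of reference words. The pure-Python sketch below illustrates that corpus-level definition with whitespace tokenization. The helper names are hypothetical, and the numbers reported in this notebook come from NeMo's `word_error_rate`, not from this code.\n",
    "\n",
    "```python\n",
    "def edit_distance(ref: list[str], hyp: list[str]) -> int:\n",
    "    \"\"\"Levenshtein distance between two token lists, keeping two rows of the DP table.\"\"\"\n",
    "    prev = list(range(len(hyp) + 1))  # distances from an empty reference prefix\n",
    "    for i, r in enumerate(ref, start=1):\n",
    "        curr = [i] + [0] * len(hyp)\n",
    "        for j, h in enumerate(hyp, start=1):\n",
    "            curr[j] = min(\n",
    "                prev[j] + 1,  # delete reference word r\n",
    "                curr[j - 1] + 1,  # insert hypothesis word h\n",
    "                prev[j - 1] + (r != h),  # substitute (free when the words match)\n",
    "            )\n",
    "        prev = curr\n",
    "    return prev[-1]\n",
    "\n",
    "\n",
    "def corpus_wer(hypotheses: list[str], references: list[str]) -> float:\n",
    "    \"\"\"Total word edits divided by total reference words, over all utterance pairs.\"\"\"\n",
    "    edits = sum(edit_distance(r.split(), h.split()) for h, r in zip(hypotheses, references))\n",
    "    words = sum(len(r.split()) for r in references)\n",
    "    return edits / words\n",
    "\n",
    "\n",
    "# One substitution (\"two\" -> \"to\") over four reference words.\n",
    "print(corpus_wer([\"yes\", \"go to three\"], [\"yes\", \"go two three\"]))  # 0.25\n",
    "```"
   ]
  },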
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb4588b4-7d52-4c4e-bb81-2bcb5a227afd",
   "metadata": {
    "id": "eb4588b4-7d52-4c4e-bb81-2bcb5a227afd"
   },
   "outputs": [],
   "source": [
    "from nemo.collections.asr.metrics.wer import word_error_rate\n",
    "from torchmetrics.text import SacreBLEUScore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1c71044-3cb3-453c-bfcd-ee551cecdddf",
   "metadata": {
    "id": "a1c71044-3cb3-453c-bfcd-ee551cecdddf"
   },
   "outputs": [],
   "source": [
    "asr_test = read_manifest(data_module.test_manifest)\n",
    "ast_test = read_manifest(data_module.ast_test_manifest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1d8acd2-aa08-4ba0-b0c6-c5d662243b00",
   "metadata": {
    "id": "f1d8acd2-aa08-4ba0-b0c6-c5d662243b00"
   },
   "outputs": [],
   "source": [
    "asr_filepaths = [x['audio_filepath'] for x in asr_test]\n",
    "asr_gt = [x['text'] for x in asr_test]\n",
    "\n",
    "ast_filepaths = [x['audio_filepath'] for x in ast_test]\n",
    "ast_gt = [x['text'] for x in ast_test]\n",
    "\n",
    "print(\"Num files:\", len(asr_filepaths))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85ace700-97bf-4697-8e1a-5793eb21e678",
   "metadata": {
    "id": "85ace700-97bf-4697-8e1a-5793eb21e678"
   },
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()  # move model to gpu\n",
    "    model = model.to(torch.bfloat16)  # cast full model to bfloat16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00f2607a-2f67-47fe-9903-0adae4d9adf5",
   "metadata": {
    "id": "00f2607a-2f67-47fe-9903-0adae4d9adf5"
   },
   "outputs": [],
   "source": [
    "asr_preds = model.transcribe(asr_filepaths, pnc='no', task='asr', source_lang='en', target_lang='en', batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eea5ab20-60d4-4e19-87fb-71f6835941e8",
   "metadata": {
    "id": "eea5ab20-60d4-4e19-87fb-71f6835941e8"
   },
   "outputs": [],
   "source": [
    "ast_preds = model.transcribe(ast_filepaths, pnc='no', task='ast', source_lang='en', target_lang='de', batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69e5bb54-5193-4268-98e1-dc6daae8f6eb",
   "metadata": {
    "id": "69e5bb54-5193-4268-98e1-dc6daae8f6eb"
   },
   "outputs": [],
   "source": [
    "wer = word_error_rate([p.text for p in asr_preds], asr_gt)\n",
    "print(\"WER\", wer)\n",
    "\n",
    "# SacreBLEUScore takes a list of hypothesis strings and, per hypothesis, a list of reference strings\n",
    "sacrebleu = SacreBLEUScore(n_gram=4)\n",
    "sacrebleu.update([p.text for p in ast_preds], [[gt] for gt in ast_gt])\n",
    "bleu = sacrebleu.compute()\n",
    "print(\"BLEU\", bleu.item() * 100)"
   ]
  },
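  {
   "cell_type": "markdown",
   "id": "8d3a6f15-b7c9-4e2d-a0f4-5e9b2c7d1a83",
   "metadata": {
    "id": "8d3a6f15-b7c9-4e2d-a0f4-5e9b2c7d1a83"
   },
   "source": [
    "The second metric, BLEU, combines clipped n-gram precisions for n = 1 to 4 (matching `n_gram=4` above) through a geometric mean, then multiplies by a brevity penalty for hypotheses shorter than their references. The sketch below computes corpus-level BLEU from scratch with one reference per hypothesis and plain whitespace tokenization. It deliberately omits SacreBLEU's standardized tokenization and any smoothing, so do not expect it to reproduce `SacreBLEUScore` numbers exactly. The helper names are hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def ngrams(tokens: list[str], n: int) -> Counter:\n",
    "    \"\"\"Count every n-gram of order n in a token list.\"\"\"\n",
    "    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))\n",
    "\n",
    "\n",
    "def corpus_bleu(hypotheses: list[str], references: list[str], max_n: int = 4) -> float:\n",
    "    \"\"\"Unsmoothed corpus BLEU in [0, 1] with one reference per hypothesis and whitespace tokens.\"\"\"\n",
    "    matches = [0] * max_n  # clipped n-gram matches, per order\n",
    "    totals = [0] * max_n  # hypothesis n-grams, per order\n",
    "    hyp_len = ref_len = 0\n",
    "    for hyp, ref in zip(hypotheses, references):\n",
    "        h, r = hyp.split(), ref.split()\n",
    "        hyp_len += len(h)\n",
    "        ref_len += len(r)\n",
    "        for n in range(1, max_n + 1):\n",
    "            h_counts, r_counts = ngrams(h, n), ngrams(r, n)\n",
    "            # Clip each hypothesis n-gram count by its count in the reference.\n",
    "            matches[n - 1] += sum(min(c, r_counts[g]) for g, c in h_counts.items())\n",
    "            totals[n - 1] += max(len(h) - n + 1, 0)\n",
    "    if hyp_len == 0 or min(matches) == 0:\n",
    "        return 0.0  # without smoothing, any zero n-gram precision makes BLEU zero\n",
    "    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n\n",
    "    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)\n",
    "    return brevity_penalty * math.exp(log_precision)\n",
    "\n",
    "\n",
    "print(corpus_bleu([\"das ist ein kleines haus\"], [\"das ist ein kleines haus\"]))  # 1.0\n",
    "print(corpus_bleu([\"das ist ein haus\"], [\"das ist ein kleines haus\"]))  # 0.0 (no matching 4-gram)\n",
    "```\n",
    "\n",
    "Like `sacrebleu.compute()` above, this returns a value in [0, 1]; the evaluation cells multiply it by 100 for display."
   ]
  },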
  {
   "cell_type": "markdown",
   "id": "5ee530c9-36a3-47d2-83b9-b2a64080c0eb",
   "metadata": {
    "id": "5ee530c9-36a3-47d2-83b9-b2a64080c0eb"
   },
   "source": [
    "# Train Model\n",
    "\n",
    "Now that the adapters are in place, the baseline has been evaluated, and the dataset is prepared, it's time to train the adapter weights on the new data.\n",
    "\n",
    "---\n",
    "\n",
    "First, we update the optimizer and scheduler config:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0a40461-d739-436c-967a-1a0f8a3ad197",
   "metadata": {
    "id": "d0a40461-d739-436c-967a-1a0f8a3ad197"
   },
   "outputs": [],
   "source": [
    "print(OmegaConf.to_yaml(model.cfg.optim))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ba5811a-fc42-4de5-add5-0d26d1c84219",
   "metadata": {
    "id": "4ba5811a-fc42-4de5-add5-0d26d1c84219"
   },
   "outputs": [],
   "source": [
    "# Setup optimization\n",
    "model.cfg.optim.lr = 3e-4\n",
    "model.cfg.optim.sched.warmup_steps = 25"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1de270a-d1cb-4080-b571-7acf365d7b99",
   "metadata": {
    "id": "d1de270a-d1cb-4080-b571-7acf365d7b99"
   },
   "source": [
    "---\n",
    "\n",
    "Next, we set up a Lightning Trainer and, optionally, the Experiment Manager (commented out below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9e34369-21ec-41bf-beae-30b60ab46c14",
   "metadata": {
    "id": "b9e34369-21ec-41bf-beae-30b60ab46c14"
   },
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf\n",
    "from nemo.utils import exp_manager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46f74863-a34d-4ad0-9d8e-3337ea5edd63",
   "metadata": {
    "id": "46f74863-a34d-4ad0-9d8e-3337ea5edd63"
   },
   "outputs": [],
   "source": [
    "trainer = L.Trainer(max_steps=200, accumulate_grad_batches=1, logger=False, enable_checkpointing=False, check_val_every_n_epoch=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "414d7887-bed5-46a2-bfe1-8349db1e6b5b",
   "metadata": {
    "id": "414d7887-bed5-46a2-bfe1-8349db1e6b5b"
   },
   "outputs": [],
   "source": [
    "# # Environment variable generally used for multi-node multi-gpu training.\n",
    "# # In notebook environments, this flag is unnecessary and can cause logs of multiple training runs to overwrite each other.\n",
    "# os.environ.pop('NEMO_EXPM_VERSION', None)\n",
    "\n",
    "# config = exp_manager.ExpManagerConfig(\n",
    "#     exp_dir=f'experiments/canary/',\n",
    "#     name=f\"Canary-Model-Adapter-Training\",\n",
    "#     checkpoint_callback_params=exp_manager.CallbackParams(\n",
    "#         monitor=\"val_wer\",\n",
    "#         mode=\"min\",\n",
    "#         always_save_nemo=False,\n",
    "#         save_best_model=False,\n",
    "#     ),\n",
    "# )\n",
    "\n",
    "# config = OmegaConf.structured(config)\n",
    "\n",
    "# logdir = exp_manager.exp_manager(trainer, config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60769859-8ed5-4f9c-b93a-a6875c7c1c73",
   "metadata": {
    "id": "60769859-8ed5-4f9c-b93a-a6875c7c1c73"
   },
   "source": [
    "---\n",
    "\n",
    "Begin training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2adb8607-a011-440d-bfa8-976c2871e8ef",
   "metadata": {
    "id": "2adb8607-a011-440d-bfa8-976c2871e8ef"
   },
   "outputs": [],
   "source": [
    "trainer.fit(model, data_module)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "MImbKiqQ6ng-",
   "metadata": {
    "id": "MImbKiqQ6ng-"
   },
   "source": [
    "---\n",
    "\n",
    "Save just the adapter parameters, which take up less than 2 MB!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "-akTdyGM6gum",
   "metadata": {
    "id": "-akTdyGM6gum"
   },
   "outputs": [],
   "source": [
    "model.save_adapters(\"adapters.pt\")\n",
    "!ls -l -- *.pt\n",
    "!du -sh *.pt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2525bec5-c42b-48c1-b03c-e8126c346238",
   "metadata": {
    "id": "2525bec5-c42b-48c1-b03c-e8126c346238"
   },
   "source": [
    "# Evaluate after Adaptation\n",
    "\n",
    "Now that the model is done training, let's evaluate it on the test set again.\n",
    "We should see a markedly higher translation BLEU and a lower WER than the baseline above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6edb5528-b1b6-4505-8cdc-ee68c715415e",
   "metadata": {
    "id": "6edb5528-b1b6-4505-8cdc-ee68c715415e"
   },
   "outputs": [],
   "source": [
    "asr_test = read_manifest(data_module.test_manifest)\n",
    "ast_test = read_manifest(data_module.ast_test_manifest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "384aa5f2-89d5-4080-a717-4d65776fae6b",
   "metadata": {
    "id": "384aa5f2-89d5-4080-a717-4d65776fae6b"
   },
   "outputs": [],
   "source": [
    "asr_filepaths = [x['audio_filepath'] for x in asr_test]\n",
    "asr_gt = [x['text'] for x in asr_test]\n",
    "\n",
    "ast_filepaths = [x['audio_filepath'] for x in ast_test]\n",
    "ast_gt = [x['text'] for x in ast_test]\n",
    "\n",
    "print(\"Num files:\", len(asr_filepaths))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48ce5b4c-d349-4d86-ad3c-ee930bb569ee",
   "metadata": {
    "id": "48ce5b4c-d349-4d86-ad3c-ee930bb569ee"
   },
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()\n",
    "    model = model.to(torch.bfloat16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49a37806-286e-4954-8f27-3829cf61d755",
   "metadata": {
    "id": "49a37806-286e-4954-8f27-3829cf61d755"
   },
   "outputs": [],
   "source": [
    "asr_preds = model.transcribe(asr_filepaths, pnc='no', task='asr', source_lang='en', target_lang='en', batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b701e014-2f71-487c-9300-a3ea89a43a45",
   "metadata": {
    "id": "b701e014-2f71-487c-9300-a3ea89a43a45"
   },
   "outputs": [],
   "source": [
    "ast_preds = model.transcribe(ast_filepaths, pnc='no', task='ast', source_lang='en', target_lang='de', batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "087054e5-c511-4094-a115-faf4a3b49d51",
   "metadata": {
    "id": "087054e5-c511-4094-a115-faf4a3b49d51"
   },
   "outputs": [],
   "source": [
    "from nemo.collections.asr.metrics.wer import word_error_rate\n",
    "from torchmetrics.text import SacreBLEUScore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef938f8f-b2db-45f6-9b30-4b3bbce2423f",
   "metadata": {
    "id": "ef938f8f-b2db-45f6-9b30-4b3bbce2423f"
   },
   "outputs": [],
   "source": [
    "wer = word_error_rate([p.text for p in asr_preds], asr_gt)\n",
    "print(\"WER\", wer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a7c2820-d394-4627-8438-0d810d89b72d",
   "metadata": {
    "id": "5a7c2820-d394-4627-8438-0d810d89b72d"
   },
   "outputs": [],
   "source": [
    "# SacreBLEUScore takes a list of hypothesis strings and, per hypothesis, a list of reference strings\n",
    "sacrebleu = SacreBLEUScore(n_gram=4)\n",
    "sacrebleu.update([p.text for p in ast_preds], [[gt] for gt in ast_gt])\n",
    "bleu = sacrebleu.compute()\n",
    "print(\"BLEU\", bleu.item() * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "521df0e6-1d3c-4709-a080-63638315c514",
   "metadata": {
    "id": "521df0e6-1d3c-4709-a080-63638315c514"
   },
   "source": [
    "# Conclusion\n",
    "\n",
    "In this tutorial, we added adapters to a multi-task model (NVIDIA Canary) and showed how to create a custom dataset to fine-tune a Canary model on new data for its existing tasks, such as ASR and AST. The primary goal of this tutorial was to show how to flexibly adapt a Canary model to any of its pre-existing tasks.\n",
    "\n",
    "In a future tutorial, we will show how to add new tasks to a pre-trained Canary model, so that you can leverage the pre-trained encoder and decoder for your own custom tasks!"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
